/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <thrift/thrift-config.h>

#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <string>

#include <thrift/concurrency/IOService.h>
#include <thrift/transport/TSocket.h>
#include <thrift/transport/TTransportException.h>

#include <boost/system/system_error.hpp>
#include <boost/asio/write.hpp>

namespace apache {
namespace thrift {
namespace transport {

using namespace std;
using namespace apache::thrift::concurrency;
using boost::asio::deadline_timer;
using boost::asio::ip::tcp;

/**
 * TSocket implementation.
 *
 */

/**
 * Creates a client socket that will connect to host:port when open() is
 * called. The socket and its timeout timer are both bound to the default
 * IOService.
 *
 * Defaults: connect/send/receive timeouts of 0 (no timer is armed),
 * keep-alive off, linger on with a value of 0, no-delay on, and
 * maxRecvRetries_ of 5.
 */
TSocket::TSocket(const string& host, int port)
  : host_(host),
    port_(port),
    socket_(IOService::getDefaultIOService()),
    deadline_(IOService::getDefaultIOService()),
    connTimeout_(0),
    sendTimeout_(0),
    recvTimeout_(0),
    keepAlive_(false),
    lingerOn_(true),
    lingerVal_(0),
    noDelay_(true),
    maxRecvRetries_(5)
{
}

/**
 * Creates a client socket with an empty host and port 0; configure the
 * target with setHost()/setPort() before calling open().
 *
 * Uses the same option defaults as the (host, port) constructor and binds
 * the socket and its timeout timer to the default IOService.
 */
TSocket::TSocket()
  : host_(),
    port_(0),
    socket_(IOService::getDefaultIOService()),
    deadline_(IOService::getDefaultIOService()),
    connTimeout_(0),
    sendTimeout_(0),
    recvTimeout_(0),
    keepAlive_(false),
    lingerOn_(true),
    lingerVal_(0),
    noDelay_(true),
    maxRecvRetries_(5)
{
}

/**
 * Destroys the transport, closing the underlying socket if it is open.
 */
TSocket::~TSocket()
{
  this->close();
}

/**
 * @return true while the underlying asio socket is open.
 */
bool TSocket::isOpen()
{
  const bool open = socket_.is_open();
  return open;
}

/**
 * Reports whether a read may be attempted.
 *
 * This implementation does not inspect the socket's receive buffer; it
 * simply returns true while the socket is open and false otherwise.
 */
bool TSocket::peek()
{
  return isOpen();
}

struct connect_handler
{
  connect_handler(const boost::shared_ptr<boost::system::error_code>& ec,
                  const boost::shared_ptr<boost::mutex>& mux,
                  const boost::shared_ptr<boost::condition_variable>& cnd)
    : error(ec)
    , mutx(mux)
    , cond(cnd)
    {

    }

  void operator () (const boost::system::error_code& ec, tcp::resolver::iterator next)
  {
    boost::unique_lock<boost::mutex> lock(*mutx);
    *error = ec;
    lock.unlock();
    cond->notify_one();
  }

  void operator () (const boost::system::error_code& ec)
  {
    if (!ec) {
      boost::unique_lock<boost::mutex> lock(*mutx);
      *error = boost::asio::error::timed_out;
      lock.unlock();
      cond->notify_one();
    }
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<boost::mutex> mutx;
  boost::shared_ptr<boost::condition_variable> cond;
};

void TSocket::openConnection()
{
  if (sendTimeout_ > 0) {
    setSendTimeout(sendTimeout_);
  }

  if (recvTimeout_ > 0) {
    setRecvTimeout(recvTimeout_);
  }

  if (keepAlive_) {
    setKeepAlive(keepAlive_);
  }

  setLinger(lingerOn_, lingerVal_);

  setNoDelay(noDelay_);


  tcp::resolver::query query(host_, apache::thrift::to_string(port_));
  tcp::resolver::iterator iter = tcp::resolver(IOService::getDefaultIOService()).resolve(query);

  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  connect_handler conn_handler(ec, mux, cond);
  boost::asio::async_connect(socket_, iter, conn_handler);

  if (connTimeout_ > 0) {
      deadline_.expires_from_now(boost::posix_time::milliseconds(connTimeout_));
      deadline_.async_wait(conn_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) { 
    cond->wait(lock);
  }
  
  if (*ec) {
    throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", ec->value());
  }
  else {
    deadline_.cancel();
  }
}

/**
 * Connects the socket unless it is already open.
 *
 * @throws TTransportException NOT_OPEN if port_ is outside [0, 65535],
 *         plus anything openConnection() throws.
 */
void TSocket::open()
{
  if (!isOpen()) {
    const bool portInRange = (port_ >= 0) && (port_ <= 0xFFFF);
    if (!portInRange) {
      throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid");
    }
    openConnection();
  }
}

void TSocket::close()
{
  using namespace boost::asio;
  if (socket_.is_open()) {
    boost::system::error_code ec;
    socket_.shutdown(socket_base::shutdown_both, ec);
    socket_.close(ec);
  }
}

struct iohandler
{
  iohandler(const boost::shared_ptr<boost::system::error_code>& ec,
            const boost::shared_ptr<std::size_t>& length,
            const boost::shared_ptr<boost::mutex>& mux,
            const boost::shared_ptr<boost::condition_variable>& cnd)
    : error(ec)
    , bytes_transferred(length)
    , mutx(mux)
    , cond(cnd)
    {

    }

  void operator() (const boost::system::error_code& ec, std::size_t length)
  {
    boost::unique_lock<boost::mutex> lock(*mutx);
    if (ec) {
      *error = ec;
      lock.unlock();
      cond->notify_one();
    }
    else {
      *error = ec;
      *bytes_transferred += length;
      lock.unlock();
      cond->notify_one();
    }
  }

  void operator() (const boost::system::error_code& ec)
  {
    if (!ec) {
      boost::unique_lock<boost::mutex> lock(*mutx);
      *error = boost::asio::error::timed_out;
      lock.unlock();
      cond->notify_one();
    }
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<std::size_t> bytes_transferred;
  boost::shared_ptr<boost::mutex> mutx;
  boost::shared_ptr<boost::condition_variable> cond;
};

uint32_t TSocket::read(uint8_t* buf, uint32_t len)
{
  if (isOpen() == false) {
    throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
  }

  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  iohandler read_handler(ec, length, mux, cond);

  socket_.async_read_some(boost::asio::buffer(buf, len), read_handler);

  if (recvTimeout_) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(recvTimeout_));
    deadline_.async_wait(read_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    if (*length > 0) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
  }
  else {
    deadline_.cancel();
  }

  // Pack data into string
  return (uint32_t)(*length);
}

uint32_t TSocket::readAll(uint8_t* buf, uint32_t len)
{
  if (isOpen() == false) {
    throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
  }
  
  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  iohandler read_handler(ec, length, mux, cond);

  boost::asio::async_read(socket_, boost::asio::buffer(buf, len), read_handler);

  if (recvTimeout_) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(recvTimeout_));
    deadline_.async_wait(read_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    if (*length == len) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
  }
  else {
    deadline_.cancel();
  }

  // Pack data into string
  return (uint32_t)(*length);
}

void TSocket::write(const uint8_t* buf, uint32_t len) 
{
  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  iohandler write_handler(ec, length, mux, cond);
  
  boost::asio::async_write(socket_, boost::asio::buffer(buf, len), write_handler);
  
  if (sendTimeout_ > 0) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(sendTimeout_));
    deadline_.async_wait(write_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    if (*length == len) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "send timeout expired");
  }
  else {
    deadline_.cancel();
  }
}

/**
 * @return a copy of the configured target host name.
 */
std::string TSocket::getHost()
{
  return std::string(host_);
}

int TSocket::getPort()
{
  return port_;
}

/**
 * Sets the target host used by the next open().
 */
void TSocket::setHost(string host)
{
  host_.assign(host);
}

void TSocket::setPort(int port)
{
  port_ = port;
}

/**
 * Records the SO_LINGER setting and applies it immediately when the socket
 * is open.
 *
 * @param on     whether lingering on close is enabled.
 * @param linger linger value passed to asio's socket_base::linger option.
 */
void TSocket::setLinger(bool on, int linger)
{
  lingerOn_ = on;
  lingerVal_ = linger;
  if (isOpen()) {
    socket_.set_option(boost::asio::socket_base::linger(on, linger));
  }
}

/**
 * Records the TCP no-delay setting and applies it immediately when the
 * socket is open.
 */
void TSocket::setNoDelay(bool noDelay)
{
  noDelay_ = noDelay;
  if (socket_.is_open()) {
    socket_.set_option(boost::asio::ip::tcp::no_delay(noDelay));
  }
}

void TSocket::setConnTimeout(int ms)
{
  connTimeout_ = ms;
}


void TSocket::setRecvTimeout(int ms)
{
  recvTimeout_ = ms;
}

void TSocket::setSendTimeout(int ms)
{
  sendTimeout_ = ms;
}

/**
 * Records the SO_KEEPALIVE setting and applies it immediately when the
 * socket is open.
 */
void TSocket::setKeepAlive(bool keepAlive)
{
  keepAlive_ = keepAlive;
  if (socket_.is_open()) {
    socket_.set_option(boost::asio::socket_base::keep_alive(keepAlive));
  }
}

void TSocket::setMaxRecvRetries(int maxRecvRetries)
{
  maxRecvRetries_ = maxRecvRetries;
}

/**
 * Describes the socket as "<Host: ... Port: ...>".
 *
 * Uses the configured host_/port_ when both are set; otherwise falls back to
 * the peer address and port reported by the connected socket.
 */
string TSocket::getSocketInfo()
{
  std::ostringstream oss;
  const bool haveConfiguredTarget = !host_.empty() && port_ != 0;
  if (haveConfiguredTarget) {
    oss << "<Host: " << host_ << " Port: " << port_ << ">";
  } else {
    oss << "<Host: " << getPeerAddress();
    oss << " Port: " << getPeerPort() << ">";
  }
  return oss.str();
}

std::string TSocket::getPeerHost()
{
  if (peerHost_.empty()) {
   
    if (socket_.is_open() == false) {
      return host_;
    }
  boost::system::error_code ec;
  boost::asio::ip::tcp::endpoint endpoint = socket_.remote_endpoint(ec);
  if (!ec) {
    peerHost_ = endpoint.address().to_string();
    peerPort_ = endpoint.port();
    peerAddress_ = peerHost_;
  }
    
  }
  return peerHost_;
}

std::string TSocket::getPeerAddress()
{
  if (peerHost_.empty()) {
    if (socket_.is_open() == false) {
      return peerAddress_;
    }
    boost::system::error_code ec;
    boost::asio::ip::tcp::endpoint endpoint = socket_.remote_endpoint(ec);
    if (!ec) {
      peerHost_ = endpoint.address().to_string();
      peerPort_ = endpoint.port();
      peerAddress_ = peerHost_;
    }
  }
  return peerAddress_;
}

/**
 * Returns the cached peer port, first giving getPeerAddress() the chance to
 * populate the cache from the open socket.
 */
int TSocket::getPeerPort()
{
  (void)getPeerAddress();
  return peerPort_;
}

/**
 * @return "host:port" of the peer as reported by getPeerHost() and
 *         getPeerPort().
 */
const std::string TSocket::getOrigin()
{
  const std::string peer = getPeerHost();
  const int peerPortNumber = getPeerPort();
  std::ostringstream origin;
  origin << peer << ':' << peerPortNumber;
  return origin.str();
}

}
}
} // apache::thrift::transport
